What is Causal Machine Learning (ML)

Why Causal ML?

  • Causal ML models the effect of an action (the treatment) on an outcome
  • Causal ML requires strong domain knowledge, since you, not the AI, define the causal relationships
  • Causal ML is an alternative to A/B testing. A controlled experiment with a large sample is better when you can run one; however, if you already have a lot of data, Causal ML can be a great option

Let's see a generated example:

In [62]:
import pandas as pd
import numpy as np
import plotly.express as px
import plotly

plotly.offline.init_notebook_mode()

np.random.seed(42)

n_samples = 100
health_index = np.random.rand(n_samples)  # Random health index values between 0 and 1
treatment_admission = np.where(health_index < 0.4, True, False)  # Health index < 0.4 should be treated
treatment_admission = np.where(
    (health_index >= 0.4) & (health_index <= 0.6), 
    np.random.choice([True, False], n_samples, p=[0.5, 0.5]), 
    treatment_admission
)

survival_percent = np.where(
    treatment_admission, 
    np.random.uniform(0.4, 0.8, n_samples), 
    np.random.uniform(0.8, 1, n_samples)
)

data = {
    'health_index': health_index,
    'treatment_admission': treatment_admission,
    'survival_percent': survival_percent
}

df = pd.DataFrame(data)

fig = px.scatter(
    df,
    x="health_index", 
    y="survival_percent", 
    color = "treatment_admission",

)
fig.show()

Here, a patient's health_index determines both their chance of survival and their chance of being treated. As a result, we seldom (or never) collect data on patients with a low health_index (< 0.4) who were NOT treated, or on patients with a high health_index (> 0.6) who were treated. A/B testing needs data from both of these groups. This is where we would like to use Causal ML to approximate the outcomes in these unsampled regions.
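The missing regions can be checked directly. This is the overlap (or positivity) condition: every kind of patient needs some chance of being both treated and untreated. Below is a minimal sketch that re-creates the simulated treatment assignment with the same seed as above. It counts the patients in the two regions that the assignment rule never produces. The counts are expected to be zero by construction, not measured from real data.

```python
import numpy as np

# Re-create the simulated treatment assignment with the same seed and call order as above
np.random.seed(42)
n_samples = 100
health_index = np.random.rand(n_samples)
treated = health_index < 0.4
treated = np.where(
    (health_index >= 0.4) & (health_index <= 0.6),
    np.random.choice([True, False], n_samples, p=[0.5, 0.5]),
    treated,
)

# Patients the rule never produces: low health but untreated, high health but treated
low_untreated = int(np.sum((health_index < 0.4) & ~treated))
high_treated = int(np.sum((health_index > 0.6) & treated))

# Only the middle band has both treated and untreated patients
mid_treated_share = treated[(health_index >= 0.4) & (health_index <= 0.6)].mean()
print(low_untreated, high_treated, mid_treated_share)
```

Only the 0.4 to 0.6 band supports a direct treated-versus-untreated comparison. Outside it, any estimate has to extrapolate from a model.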

We will explore Causal ML using the CDC's 2022 infant birth data, focusing on the effect of NICU admission (Admission_NICU) on infant survival (Infant_Living).

Causal Graph Definition

In [5]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import dowhy
from dowhy import CausalModel
import statsmodels.api as sm
import warnings
warnings.filterwarnings('ignore')
In [6]:
df = pd.read_csv("Birth_US_2022.csv")
df.columns
Out[6]:
Index(['MM', 'hhmm', 'DOW', 'Birth_Place', 'Mother_Age', 'Mother_Nativity',
       'Mother_Race', 'Marital_Status', 'Mother_Education', 'Father_Age',
       'Father_Race', 'Father_Education', 'Prior_Birth_Living',
       'Prior_Birth_Dead', 'Prior_Birth_Termination', 'Prenatal_Care (Month)',
       'Prenatal_Visit', 'WIC', 'Cig_before_Pregnant', 'Cig_1st_Trimester',
       'Cig_2nd_Trimester', 'Cig_3rd_Trimester', 'Mother_Height (Inch)',
       'Mother_BMI', 'Mother_Pre-Preg_Weight (Pound)',
       'Mother_Delivery_Weight (Pound)', 'Pre-Preg_Diabetes',
       'Gestational_Diabetes', 'Pre-Preg_Hypertension',
       'Gestational_Hypertension', 'Hypertension_Eclampsia',
       'Previous_Preterm_Birth', 'Infertility_Treatment',
       'Fertility_Enhancing', 'Asst_Reproductive', 'Previous_Cesareans',
       'Gonorrhea', 'Syphilis', 'Chlamydia', 'Hepatitis_B', 'Hepatitis_C',
       'Maternal_Transfusion', 'Perineal_Laceration', 'Ruptured_Uterus',
       'Unplanned_Hysterectomy', 'Admit_to_Intensive_Care',
       'Payment_for_Delivery', 'Infant_Sex', 'Birth_Weight (g)',
       'Limb_Reduction_Defect', 'Cleft_Lip', 'Down_Syndrome',
       'Suspected_Chromosomal_Disorder', 'Hypospadias', 'APGAR_5min',
       'Assisted_Ventil_Immediate', 'Assisted_Ventil_6h', 'Admission_NICU',
       'Surfactant', 'Newborn_Antibiotics', 'Anencephaly', 'Meningomyelocele',
       'Cyanotic_Congenital_Heart_Disease', 'Congenital_Diaphragmatic_Hernia',
       'Omphalocele', 'Gastroschisis', 'Infant_Living'],
      dtype='object')

Data Cleaning

There are many columns, so let's keep only the outcome, the treatment and the newborn-health columns to make the data more readable.

In [7]:
try:
    df = df.loc[:, [
        "Infant_Living",
        "Admission_NICU",
        
        "Birth_Weight (g)",
        "Limb_Reduction_Defect",
        "Cleft_Lip",
        "Down_Syndrome",
        "Suspected_Chromosomal_Disorder",
        "Hypospadias",
        "APGAR_5min",
        "Gastroschisis",
        "Omphalocele",
        "Cyanotic_Congenital_Heart_Disease",
        "Congenital_Diaphragmatic_Hernia",
        "Meningomyelocele",
        "Anencephaly"
    ]]
except KeyError:
    # The column subset was already applied in an earlier run of this cell
    print("Already Dropped!")
print(df.columns)
Index(['Infant_Living', 'Admission_NICU', 'Birth_Weight (g)',
       'Limb_Reduction_Defect', 'Cleft_Lip', 'Down_Syndrome',
       'Suspected_Chromosomal_Disorder', 'Hypospadias', 'APGAR_5min',
       'Gastroschisis', 'Omphalocele', 'Cyanotic_Congenital_Heart_Disease',
       'Congenital_Diaphragmatic_Hernia', 'Meningomyelocele', 'Anencephaly'],
      dtype='object')

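Resampling to Balance the Outcome

We draw 8,000 surviving and 8,000 non-surviving infants so that both outcomes are equally represented. Because this sample is selected on the outcome, survival rates and effect estimates computed from it describe the balanced sample, not the full 2022 birth population.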

In [9]:
df_true_sample = df[df['Infant_Living'] == "Y"].sample(n=8000, replace=False)
df_false_sample = df[df['Infant_Living'] == "N"].sample(n=8000, replace=False)
df = pd.concat([df_true_sample, df_false_sample])
df.head()
Out[9]:
Infant_Living Admission_NICU Birth_Weight (g) Limb_Reduction_Defect Cleft_Lip Down_Syndrome Suspected_Chromosomal_Disorder Hypospadias APGAR_5min Gastroschisis Omphalocele Cyanotic_Congenital_Heart_Disease Congenital_Diaphragmatic_Hernia Meningomyelocele Anencephaly
3330954 Y N 3630.0 N N N N N 9.0 N N N N N N
2050659 Y N 3835.0 N N N N N 9.0 N N N N N N
1802149 Y N 2925.0 N N N N N 9.0 N N N N N N
299168 Y N 2865.0 N N N N N 9.0 N N N N N N
1615386 Y N 2254.0 N N N N N 9.0 N N N N N N
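The cell above uses two `.sample` calls plus `pd.concat`. The same balanced draw can also be written as a single stratified sample with `DataFrame.groupby(...).sample` (available in pandas 1.1 and later). Below is a minimal sketch on hypothetical toy rows. Passing `random_state` also makes the draw reproducible, which the cell above does not do. With `replace=False`, `n` cannot exceed the size of the smallest class.

```python
import pandas as pd

# Hypothetical toy data: 90 survivors and 10 deaths (an imbalanced outcome)
toy = pd.DataFrame({
    "Infant_Living": ["Y"] * 90 + ["N"] * 10,
    "Birth_Weight (g)": range(100),
})

# Draw the same number of rows from each outcome class in one reproducible call
n_per_class = 10
balanced = toy.groupby("Infant_Living").sample(n=n_per_class, replace=False, random_state=42)

print(balanced["Infant_Living"].value_counts())
```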

String to Bool

In [10]:
orig_len = len(df)

BOOL_COLS = [
    "Infant_Living",
    "Admission_NICU",
    "Limb_Reduction_Defect",
    "Cleft_Lip",
    "Down_Syndrome",
    "Suspected_Chromosomal_Disorder",
    "Hypospadias",
    "Gastroschisis",
    "Omphalocele",
    "Cyanotic_Congenital_Heart_Disease",
    "Congenital_Diaphragmatic_Hernia",
    "Meningomyelocele",
    "Anencephaly"
]
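# N/Y = no/yes, U = unknown; C = confirmed and P = pending (used by the chromosomal-disorder columns)
mapping = {'N': False, 'Y': True, 'U': pd.NA, "C": True, "P": pd.NA}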
df[BOOL_COLS] = df[BOOL_COLS].replace(mapping)
df.dropna(inplace = True)

for col in BOOL_COLS:
    try:
        df[col] = df[col].astype('boolean')
    except (TypeError, ValueError):
        # Report columns that still contain codes the mapping did not cover
        print(col)

print("Non-Null Ratio: ", len(df)/orig_len)
Non-Null Ratio:  0.901625
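About 10% of the balanced sample is lost when rows with missing values are dropped. These include rows whose U (unknown) and P (pending) codes were mapped to pd.NA. Below is a minimal sketch on hypothetical toy rows showing what the mapping does. It adds a per-column missing-value count, which is not in the original cell. Run on the real data before `dropna`, that count would show which columns drive the loss.

```python
import pandas as pd

# Hypothetical toy rows using the same codes as the birth data
toy = pd.DataFrame({
    "Infant_Living": ["Y", "N", "Y", "Y"],
    "Down_Syndrome": ["N", "C", "P", "U"],
})

mapping = {"N": False, "Y": True, "U": pd.NA, "C": True, "P": pd.NA}
converted = toy.replace(mapping)

# Missing values per column after mapping, i.e. which columns cause rows to be dropped
na_counts = converted.isna().sum()
print(na_counts)

# Drop incomplete rows and convert to pandas' nullable boolean dtype
cleaned = converted.dropna().astype("boolean")
print(cleaned)
```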

Birth Weight g -> kg

In [11]:
df['Birth_Weight (kg)'] = df['Birth_Weight (g)'] / 1000  # vectorized g -> kg conversion
df.drop(columns = ['Birth_Weight (g)'], inplace = True)
df
Out[11]:
Infant_Living Admission_NICU Limb_Reduction_Defect Cleft_Lip Down_Syndrome Suspected_Chromosomal_Disorder Hypospadias APGAR_5min Gastroschisis Omphalocele Cyanotic_Congenital_Heart_Disease Congenital_Diaphragmatic_Hernia Meningomyelocele Anencephaly Birth_Weight (kg)
3330954 True False False False False False False 9.0 False False False False False False 3.630
2050659 True False False False False False False 9.0 False False False False False False 3.835
1802149 True False False False False False False 9.0 False False False False False False 2.925
299168 True False False False False False False 9.0 False False False False False False 2.865
1615386 True False False False False False False 9.0 False False False False False False 2.254
... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
1372627 False False False False False False False 1.0 False False False False False False 0.286
1115061 False True False False False False False 5.0 False False True True False False 2.335
1143164 False True False False False False False 6.0 False False False False False False 3.545
3262457 False False False False False False False 1.0 False False False False False False 0.370
918732 False True False False False False False 5.0 False False False False False False 2.450

14426 rows × 15 columns

Drawing Causal Graph

Based on the data above, we can propose the following dependency graph (which relies on many assumptions about diseases in which I am not an expert):

In [12]:
causal_graph = """
digraph {
    Infant_Living;
    Admission_NICU;
    APGAR_5min;
    Gastroschisis;
    Omphalocele;
    Cyanotic_Congenital_Heart_Disease;
    Congenital_Diaphragmatic_Hernia;
    Weight[label="Birth_Weight (kg)"];
    
    Admission_NICU -> Infant_Living;
    
    APGAR_5min -> Admission_NICU;
    Weight -> Infant_Living;
    
    Gastroschisis -> Admission_NICU;
    Gastroschisis -> Infant_Living;
    Cyanotic_Congenital_Heart_Disease -> Admission_NICU;
    Cyanotic_Congenital_Heart_Disease -> Infant_Living;
    Congenital_Diaphragmatic_Hernia -> Admission_NICU;
    Congenital_Diaphragmatic_Hernia -> Infant_Living;
    Omphalocele -> Admission_NICU;
    Omphalocele -> Infant_Living;
}"""
In [13]:
model = dowhy.CausalModel(
    data=df.reset_index(drop = True),
    graph=causal_graph.replace("\n", " "),
    treatment="Admission_NICU",
    outcome="Infant_Living"
)
model.view_model(size=(12, 8))

Assumptions:

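  • APGAR_5min, a score of a newborn's condition five minutes after birth, is used to decide whether the newborn needs to be admitted to the NICU
  • Birth weight, an indicator of premature birth, affects infant survival
  • The remaining congenital conditions in the graph (Gastroschisis, Omphalocele, Cyanotic_Congenital_Heart_Disease, Congenital_Diaphragmatic_Hernia) affect both NICU admission and survival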
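The graph already determines which variables the backdoor estimand must adjust for. In this particular graph, every other node is a direct parent of the treatment, the outcome or both. That means the adjustment set reduces to the common parents of Admission_NICU and Infant_Living. Below is a minimal sketch of that reasoning in plain Python, with the edges copied from the graph above. It is not a general backdoor-criterion algorithm (DoWhy performs the general search). The sets it prints should line up with the estimands identified in the next step.

```python
# Edges of the causal graph above, written as (cause, effect) pairs
edges = [
    ("Admission_NICU", "Infant_Living"),
    ("APGAR_5min", "Admission_NICU"),
    ("Weight", "Infant_Living"),
    ("Gastroschisis", "Admission_NICU"),
    ("Gastroschisis", "Infant_Living"),
    ("Cyanotic_Congenital_Heart_Disease", "Admission_NICU"),
    ("Cyanotic_Congenital_Heart_Disease", "Infant_Living"),
    ("Congenital_Diaphragmatic_Hernia", "Admission_NICU"),
    ("Congenital_Diaphragmatic_Hernia", "Infant_Living"),
    ("Omphalocele", "Admission_NICU"),
    ("Omphalocele", "Infant_Living"),
]

def parents(node):
    """Direct causes of a node in the edge list."""
    return {cause for cause, effect in edges if effect == node}

treatment, outcome = "Admission_NICU", "Infant_Living"
treatment_parents = parents(treatment)
outcome_parents = parents(outcome) - {treatment}

# Common causes open backdoor paths, so they form the adjustment set in this graph
common_causes = treatment_parents & outcome_parents
# Parents of the treatment only: candidate instruments (affect outcome only via treatment)
instrument_candidates = treatment_parents - outcome_parents
# Parents of the outcome only: not needed for identification
outcome_only = outcome_parents - treatment_parents

print(sorted(common_causes))
print(sorted(instrument_candidates))
print(sorted(outcome_only))
```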

Model Training

Identify Estimand

In [15]:
identified_estimand = model.identify_effect(
    method_name = "exhaustive-search", 
    proceed_when_unidentifiable=True,
)
print(identified_estimand)
Estimand type: EstimandType.NONPARAMETRIC_ATE

### Estimand : 1
Estimand name: backdoor
Estimand expression:
        d
─────────────────(E[Infant_Living|Congenital_Diaphragmatic_Hernia,Gastroschisis,Cyanotic_Congenital_Heart_Disease,Omphalocele])
d[Admission_NICU]
Estimand assumption 1, Unconfoundedness: If U→{Admission_NICU} and U→Infant_Living then P(Infant_Living|Admission_NICU,Congenital_Diaphragmatic_Hernia,Gastroschisis,Cyanotic_Congenital_Heart_Disease,Omphalocele,U) = P(Infant_Living|Admission_NICU,Congenital_Diaphragmatic_Hernia,Gastroschisis,Cyanotic_Congenital_Heart_Disease,Omphalocele)

### Estimand : 2
Estimand name: iv
Estimand expression:
 ⎡                                                            -1⎤
 ⎢     d                      ⎛     d                        ⎞  ⎥
E⎢────────────(Infant_Living)⋅⎜────────────([Admission_NICU])⎟  ⎥
 ⎣d[APGAR₅ₘᵢₙ]                ⎝d[APGAR₅ₘᵢₙ]                  ⎠  ⎦
Estimand assumption 1, As-if-random: If U→→Infant_Living then ¬(U →→{APGAR_5min})
Estimand assumption 2, Exclusion: If we remove {APGAR_5min}→{Admission_NICU}, then ¬({APGAR_5min}→Infant_Living)

### Estimand : 3
Estimand name: frontdoor
No such variable(s) found!

Estimate Effect using Backdoor

In [38]:
estimate = model.estimate_effect(
    identified_estimand,
    method_name="backdoor.propensity_score_weighting",
    method_params = {"glm_family":sm.families.Binomial()}
)
print(estimate)
*** Causal Estimate ***

## Identified estimand
Estimand type: EstimandType.NONPARAMETRIC_ATE

### Estimand : 1
Estimand name: backdoor
Estimand expression:
        d
─────────────────(E[Infant_Living|Congenital_Diaphragmatic_Hernia,Gastroschisis,Cyanotic_Congenital_Heart_Disease,Omphalocele])
d[Admission_NICU]
Estimand assumption 1, Unconfoundedness: If U→{Admission_NICU} and U→Infant_Living then P(Infant_Living|Admission_NICU,Congenital_Diaphragmatic_Hernia,Gastroschisis,Cyanotic_Congenital_Heart_Disease,Omphalocele,U) = P(Infant_Living|Admission_NICU,Congenital_Diaphragmatic_Hernia,Gastroschisis,Cyanotic_Congenital_Heart_Disease,Omphalocele)

## Realized estimand
b: Infant_Living~Admission_NICU+Congenital_Diaphragmatic_Hernia+Gastroschisis+Cyanotic_Congenital_Heart_Disease+Omphalocele
Target units: ate

## Estimate
Mean value: -0.39075432372614205

In [39]:
# Interpreting Results
interpretation = estimate.interpret(method_name="textual_effect_interpreter")
interpretation
Increasing the treatment variable(s) [Admission_NICU] from 0 to 1 causes an increase of -0.39075432372614205 in the expected value of the outcome [['Infant_Living']], over the data distribution/population represented by the dataset.
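The propensity_score_weighting method reweights each unit by the inverse probability of receiving the treatment it actually got. This makes the treated and untreated groups resemble each other on the adjusted confounders.

Below is a from-scratch sketch on simulated data. It assumes one hypothetical binary confounder `x` (think of a congenital defect) that makes both treatment and death more likely. DoWhy's estimator may use a different propensity model and weighting scheme, so this is not its exact computation. The simulation also shows how confounding can make a helpful treatment look harmful in a naive comparison.

```python
import numpy as np
import pandas as pd

rng = np.random.default_rng(0)
n = 200_000

# Hypothetical binary confounder present in 20% of infants
x = (rng.random(n) < 0.2).astype(int)
# Infants with the confounder are much more likely to be treated (90% vs 10%)
t = (rng.random(n) < np.where(x == 1, 0.9, 0.1)).astype(int)
# Assumed true effect: treatment raises survival probability by 0.05
y = (rng.random(n) < 0.9 - 0.4 * x + 0.05 * t).astype(int)

# Naive comparison ignores the confounder
naive = y[t == 1].mean() - y[t == 0].mean()

# Propensity score e(x) = P(T=1 | X=x), estimated per stratum because x is binary
e = pd.Series(t).groupby(x).transform("mean").to_numpy()

# Inverse propensity weighting (Horvitz-Thompson form) estimate of the ATE
ipw = np.mean(t * y / e) - np.mean((1 - t) * y / (1 - e))
print(f"naive: {naive:.3f}, IPW: {ipw:.3f}")
```

With a single discrete confounder and per-stratum propensities, this IPW estimate equals the stratum-weighted difference in survival rates. A real analysis would fit a propensity model, for example a logistic regression, on all adjusted covariates.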

Refute

Placebo Treatment

In [18]:
res_placebo=model.refute_estimate(
    identified_estimand, estimate,
    method_name="placebo_treatment_refuter", 
    show_progress_bar=True, 
    placebo_type="permute"
)
print(res_placebo)
Refute: Use a Placebo Treatment
Estimated effect:-0.39075432372614205
New effect:0.009697896999021228
p value:0.41999999999999993

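Conclusion: The new effect under a placebo treatment is close to zero (about 0.01) and consistent with zero (p ≈ 0.42), so the placebo test does not refute the estimate of -0.39. This test only checks that the effect disappears when the real treatment is replaced by a random one. It cannot detect confounders that are missing from the causal graph.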
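The placebo_treatment_refuter with placebo_type="permute" shuffles the treatment column and re-runs the estimator. If the pipeline is sound, the effect should collapse toward zero.

Below is a minimal sketch of that idea on simulated data. It uses a hand-written IPW estimator (`ipw_ate`, a hypothetical helper) rather than DoWhy. DoWhy's refuter repeats this over many simulations and reports a significance test, which this sketch only approximates by averaging a few permutations.

```python
import numpy as np
import pandas as pd

rng = np.random.default_rng(0)
n = 100_000

# Simulated data: binary confounder x, treatment t, binary outcome y (assumed true effect +0.05)
x = (rng.random(n) < 0.2).astype(int)
t = (rng.random(n) < np.where(x == 1, 0.9, 0.1)).astype(int)
y = (rng.random(n) < 0.9 - 0.4 * x + 0.05 * t).astype(int)

def ipw_ate(x, t, y):
    """IPW estimate of the ATE with per-stratum propensities for a binary confounder."""
    e = pd.Series(t).groupby(x).transform("mean").to_numpy()
    return np.mean(t * y / e) - np.mean((1 - t) * y / (1 - e))

sim_estimate = ipw_ate(x, t, y)

# Placebo: permuting the treatment breaks any real treatment-outcome link
placebo_effects = [ipw_ate(x, rng.permutation(t), y) for _ in range(20)]
print(f"estimate: {sim_estimate:.3f}, mean placebo effect: {np.mean(placebo_effects):.3f}")
```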

Conclusion

As stated at the beginning, one needs strong domain knowledge to use Causal ML effectively. This study's hypothesis appears to have been flawed from the start. The estimated effect of Admission_NICU on Infant_Living is negative. This likely means that, as modelled here, NICU admission does not act as a treatment that improves infant survival.

In [49]:
import plotly.express as px

plot_df = df.groupby(["Admission_NICU", "APGAR_5min"])['Infant_Living'].mean().reset_index()
fig = px.line(
    plot_df,
    x="APGAR_5min", 
    y="Infant_Living", 
    color = "Admission_NICU",

)
fig.show()

The chart above looks very different from the generated example in the first section (What is Causal Machine Learning (ML)). This suggests that the decision to admit a newborn to the NICU is NOT determined by APGAR alone. Other variables that are crucial to a newborn's chance of survival are not modelled in this study and would require further research. Regardless, this was still a valuable learning experience.
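The chart plots survival, so it only indirectly reflects how APGAR drives admission. A more direct check is the share of infants admitted to the NICU at each APGAR score, P(NICU | APGAR). Below is a minimal sketch on hypothetical toy rows that use the notebook's column names; on the cleaned `df`, the same groupby line applies. If the share barely changes with APGAR, that supports the claim above. If it rises steeply as APGAR falls, APGAR does drive admission.

```python
import pandas as pd

# Hypothetical toy rows standing in for the cleaned birth data
toy = pd.DataFrame({
    "APGAR_5min": [9, 9, 9, 9, 5, 5, 5, 1, 1],
    "Admission_NICU": [False, False, False, True, True, True, False, True, False],
})

# Share of infants admitted to the NICU at each APGAR score, i.e. P(NICU | APGAR)
nicu_rate = toy.groupby("APGAR_5min")["Admission_NICU"].mean()
print(nicu_rate)
```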
